{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "693afa39",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "\n",
    "from data_loading import sine_data_generation\n",
    "from utils import random_generator"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8f2248e6",
   "metadata": {},
   "source": [
    "Generate ONE batch (= 128 samples) of sine data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "5d350d1d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Data layout: number of samples, time steps per sample, features per time step\n",
    "no = 128\n",
    "seq_len = 24\n",
    "dim = 5"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "3ab77f3a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Generate the sine data set using the sizes defined in the previous cell\n",
    "data = sine_data_generation(no, seq_len, dim)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "6a6e8abc",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[0.54205624, 0.54337448, 0.50690871, 0.52631528, 0.501343  ],\n",
       "       [0.5859264 , 0.5916374 , 0.54355345, 0.55005655, 0.5098356 ],\n",
       "       [0.62912402, 0.63903061, 0.57996366, 0.57368429, 0.51832536],\n",
       "       [0.67131102, 0.68510434, 0.61594325, 0.59714494, 0.52680983],\n",
       "       [0.71215719, 0.7294213 , 0.65129848, 0.62038528, 0.53528657],\n",
       "       [0.75134285, 0.77156091, 0.68583896, 0.6433526 , 0.54375312],\n",
       "       [0.7885613 , 0.81112323, 0.71937868, 0.66599483, 0.55220705],\n",
       "       [0.82352123, 0.8477328 , 0.75173702, 0.68826061, 0.56064592],\n",
       "       [0.85594902, 0.88104215, 0.78273975, 0.71009945, 0.56906729],\n",
       "       [0.88559087, 0.91073517, 0.81221989, 0.73146181, 0.57746873],\n",
       "       [0.91221477, 0.93653004, 0.84001871, 0.75229927, 0.58584782],\n",
       "       [0.93561234, 0.95818197, 0.8659865 , 0.77256455, 0.59420214],\n",
       "       [0.95560045, 0.97548546, 0.88998342, 0.7922117 , 0.60252928],\n",
       "       [0.97202267, 0.98827628, 0.91188025, 0.81119617, 0.61082683],\n",
       "       [0.98475045, 0.99643304, 0.93155908, 0.82947489, 0.61909241],\n",
       "       [0.99368418, 0.99987834, 0.94891392, 0.84700643, 0.62732363],\n",
       "       [0.99875394, 0.99857946, 0.96385132, 0.86375102, 0.6355181 ],\n",
       "       [0.99992004, 0.99254875, 0.97629085, 0.87967068, 0.64367348],\n",
       "       [0.99717336, 0.98184344, 0.98616551, 0.89472931, 0.6517874 ],\n",
       "       [0.9905354 , 0.96656512, 0.99342214, 0.90889277, 0.65985753],\n",
       "       [0.98005812, 0.94685879, 0.99802164, 0.92212893, 0.66788153],\n",
       "       [0.9658235 , 0.92291149, 0.99993926, 0.93440777, 0.6758571 ],\n",
       "       [0.94794298, 0.89495048, 0.99916467, 0.94570146, 0.68378192],\n",
       "       [0.92655649, 0.86324115, 0.99570203, 0.95598437, 0.69165372]])"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Show the first generated sample\n",
    "data[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "0cba6bf2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# CPU device handle.\n",
    "# NOTE(review): no cell in this notebook passes `device` to a tensor or module - confirm whether that is intended.\n",
    "device = torch.device(\"cpu\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b620dd1a",
   "metadata": {},
   "source": [
    "Classes for the individual network components"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "id": "5b0a1ee2",
   "metadata": {},
   "outputs": [],
   "source": [
    "class Time_GAN_module(nn.Module):\n",
    "    \"\"\"\n",
    "    Generic building block of the Time GAN architecture: a stack of\n",
    "    `n_layers` GRU layers followed by one fully connected layer that is\n",
    "    applied to every time step.\n",
    "\n",
    "    Args:\n",
    "        input_size: number of features per time step of the input (data\n",
    "            dimension or latent dimension, depending on where the module sits).\n",
    "        output_size: number of features produced by the fully connected layer.\n",
    "        hidden_dim: size of the GRU hidden state.\n",
    "        n_layers: number of stacked GRU layers.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, input_size, output_size, hidden_dim, n_layers):\n",
    "        super().__init__()\n",
    "\n",
    "        # Kept as attributes because forward() and init_hidden() need them\n",
    "        self.hidden_dim = hidden_dim\n",
    "        self.n_layers = n_layers\n",
    "\n",
    "        # batch_first=True: inputs and outputs are laid out as (batch, seq, feature)\n",
    "        self.rnn = nn.GRU(input_size, hidden_dim, n_layers, batch_first=True)\n",
    "        # Fully connected layer mapping every hidden vector to output_size features\n",
    "        self.fc = nn.Linear(hidden_dim, output_size)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"\n",
    "        Run `x` through the GRU stack and the fully connected layer.\n",
    "\n",
    "        Args:\n",
    "            x: tensor of shape (batch, seq_len, input_size).\n",
    "\n",
    "        Returns:\n",
    "            out: tensor of shape (batch * seq_len, output_size); the time\n",
    "                steps are flattened into the first dimension, so callers\n",
    "                reshape it back to (batch, seq_len, output_size).\n",
    "            hidden: final GRU hidden state of shape (n_layers, batch, hidden_dim).\n",
    "        \"\"\"\n",
    "        batch_size = x.size(0)\n",
    "\n",
    "        # The initial hidden state must live on the same device and use the\n",
    "        # same dtype as the input, otherwise the GRU raises an error.\n",
    "        hidden = self.init_hidden(batch_size, device=x.device, dtype=x.dtype)\n",
    "\n",
    "        out, hidden = self.rnn(x, hidden)\n",
    "\n",
    "        # Flatten (batch, seq_len, hidden_dim) so the linear layer sees one row per time step\n",
    "        out = out.reshape(-1, self.hidden_dim)\n",
    "        out = self.fc(out)\n",
    "\n",
    "        # TODO: add a sigmoid activation.\n",
    "        # TODO: the paper implementation also computes the hidden states but does not seem to use them?\n",
    "\n",
    "        return out, hidden\n",
    "\n",
    "    def init_hidden(self, batch_size, device=None, dtype=None):\n",
    "        \"\"\"\n",
    "        Create the all-zero initial hidden state for the GRU.\n",
    "\n",
    "        Args:\n",
    "            batch_size: number of sequences in the batch.\n",
    "            device: target device; defaults to the device of the module parameters.\n",
    "            dtype: target dtype; defaults to the dtype of the module parameters.\n",
    "\n",
    "        Returns:\n",
    "            Zero tensor of shape (n_layers, batch_size, hidden_dim).\n",
    "        \"\"\"\n",
    "        if device is None or dtype is None:\n",
    "            reference = next(self.parameters())\n",
    "            device = reference.device if device is None else device\n",
    "            dtype = reference.dtype if dtype is None else dtype\n",
    "        return torch.zeros(self.n_layers, batch_size, self.hidden_dim, device=device, dtype=dtype)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8ce0cf93",
   "metadata": {},
   "source": [
    "Embedder & Recovery"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "e51e29b9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Convert the generated data to a float32 tensor; going through a single NumPy array\n",
    "# replaces the legacy torch.Tensor(...) constructor with an explicit dtype.\n",
    "x = torch.tensor(np.asarray(data), dtype=torch.float32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "4dbb5be1",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([128, 24, 5])"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Check the tensor layout\n",
    "x.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "ca4aa6cd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Network hyperparameters: feature sizes first, then GRU width and depth\n",
    "input_size, output_size = 5, 20\n",
    "hidden_dim, n_layers = 20, 3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "id": "1da74c2f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Time_GAN_module(\n",
       "  (rnn): GRU(5, 20, num_layers=3, batch_first=True)\n",
       "  (fc): Linear(in_features=20, out_features=20, bias=True)\n",
       ")"
      ]
     },
     "execution_count": 43,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Embedder: takes `dim` features per time step and emits `hidden_dim` features\n",
    "Embedder = Time_GAN_module(\n",
    "    input_size=dim,\n",
    "    output_size=hidden_dim,\n",
    "    hidden_dim=hidden_dim,\n",
    "    n_layers=n_layers,\n",
    ")\n",
    "Embedder"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "id": "c0183e0f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([128, 24, 20])"
      ]
     },
     "execution_count": 44,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# The module returns its output with the time steps flattened into the first dimension.\n",
    "# Restore (batch, seq_len, features) from the input shape instead of hard-coding 128 / 24 / 20.\n",
    "H, _ = Embedder(x)\n",
    "\n",
    "H = H.reshape(x.shape[0], x.shape[1], -1)\n",
    "\n",
    "H.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "id": "740539ab",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Time_GAN_module(\n",
       "  (rnn): GRU(20, 20, num_layers=3, batch_first=True)\n",
       "  (fc): Linear(in_features=20, out_features=5, bias=True)\n",
       ")"
      ]
     },
     "execution_count": 45,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Recovery: takes the feature size of H per time step and emits `dim` features\n",
    "Recovery = Time_GAN_module(\n",
    "    input_size=H.shape[2],\n",
    "    output_size=dim,\n",
    "    hidden_dim=hidden_dim,\n",
    "    n_layers=n_layers,\n",
    ")\n",
    "Recovery"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "id": "d9ad3b5c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([128, 24, 5])"
      ]
     },
     "execution_count": 46,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Reconstruct the data from the latent representation H.\n",
    "# Batch and sequence dimensions are taken from H instead of hard-coding 128 / 24 / 5.\n",
    "x_tilde, _ = Recovery(H)\n",
    "\n",
    "x_tilde = x_tilde.reshape(H.shape[0], H.shape[1], -1)\n",
    "\n",
    "x_tilde.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bbb3f88a",
   "metadata": {},
   "source": [
    "Generator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "66b6c549",
   "metadata": {},
   "outputs": [],
   "source": [
    "from utils import extract_time\n",
    "\n",
    "time_information, max_seq_len = extract_time(data)\n",
    "\n",
    "# time_information is a list in which each element is the number of time points of the corresponding sample"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "id": "29f54e7a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Draw one noise sequence per sample; use `no` instead of repeating the literal 128\n",
    "random_data = random_generator(batch_size=no, z_dim=dim, T_mb=time_information, max_seq_len=max_seq_len)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 82,
   "id": "1add867e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([128, 24, 5])"
      ]
     },
     "execution_count": 82,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Stack the noise into one NumPy array first (building a tensor from a list of arrays is slow,\n",
    "# if random_data is such a list) and request float32 directly instead of casting afterwards.\n",
    "z = torch.tensor(np.asarray(random_data), dtype=torch.float32)\n",
    "\n",
    "z.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 62,
   "id": "8f4f9fa9",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Time_GAN_module(\n",
       "  (rnn): GRU(5, 20, num_layers=3, batch_first=True)\n",
       "  (fc): Linear(in_features=20, out_features=20, bias=True)\n",
       ")"
      ]
     },
     "execution_count": 62,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Generator: takes `dim` features per time step and emits `hidden_dim` features\n",
    "Generator = Time_GAN_module(\n",
    "    input_size=dim,\n",
    "    output_size=hidden_dim,\n",
    "    hidden_dim=hidden_dim,\n",
    "    n_layers=n_layers,\n",
    ")\n",
    "Generator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 83,
   "id": "9dab4fc7",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([128, 24, 20])"
      ]
     },
     "execution_count": 83,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# The random noise z must go through the Generator (it was passed to the Embedder before,\n",
    "# which left the Generator unused). Batch and sequence dimensions are taken from z.\n",
    "e_hat, _ = Generator(z)\n",
    "\n",
    "e_hat = e_hat.reshape(z.shape[0], z.shape[1], -1)\n",
    "\n",
    "e_hat.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 86,
   "id": "adc8f90c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Time_GAN_module(\n",
       "  (rnn): GRU(20, 20, num_layers=2, batch_first=True)\n",
       "  (fc): Linear(in_features=20, out_features=20, bias=True)\n",
       ")"
      ]
     },
     "execution_count": 86,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Supervisor: operates on the feature size of e_hat and uses one GRU layer fewer than the other modules\n",
    "Supervisor = Time_GAN_module(\n",
    "    input_size=e_hat.shape[2],\n",
    "    output_size=hidden_dim,\n",
    "    hidden_dim=hidden_dim,\n",
    "    n_layers=n_layers - 1,\n",
    ")\n",
    "Supervisor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 90,
   "id": "f36592b3",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([128, 24, 20])"
      ]
     },
     "execution_count": 90,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Pass e_hat through the Supervisor; restore (batch, seq_len, features)\n",
    "# from the input shape instead of hard-coding 128 / 24 / 20.\n",
    "H_hat, _ = Supervisor(e_hat)\n",
    "\n",
    "H_hat = H_hat.reshape(e_hat.shape[0], e_hat.shape[1], -1)\n",
    "\n",
    "H_hat.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 94,
   "id": "8367a453",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([128, 24, 20])"
      ]
     },
     "execution_count": 94,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Pass the embedded real data H through the Supervisor; restore (batch, seq_len, features)\n",
    "# from the input shape instead of hard-coding 128 / 24 / 20.\n",
    "H_hat_supervise, _ = Supervisor(H)\n",
    "\n",
    "H_hat_supervise = H_hat_supervise.reshape(H.shape[0], H.shape[1], -1)\n",
    "\n",
    "H_hat_supervise.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5f45a8f1",
   "metadata": {},
   "source": [
    "Recovery"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 100,
   "id": "13a5231a",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([128, 24, 5])"
      ]
     },
     "execution_count": 100,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Map H_hat back to the data space; restore (batch, seq_len, features)\n",
    "# from the input shape instead of hard-coding 128 / 24 / 5.\n",
    "x_hat, _ = Recovery(H_hat)\n",
    "\n",
    "x_hat = x_hat.reshape(H_hat.shape[0], H_hat.shape[1], -1)\n",
    "\n",
    "x_hat.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4d27658d",
   "metadata": {},
   "source": [
    "Discriminator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 101,
   "id": "84b82b4c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Time_GAN_module(\n",
       "  (rnn): GRU(20, 20, num_layers=3, batch_first=True)\n",
       "  (fc): Linear(in_features=20, out_features=1, bias=True)\n",
       ")"
      ]
     },
     "execution_count": 101,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Discriminator: takes the feature size of H per time step and emits a single value per time step\n",
    "Discriminator = Time_GAN_module(\n",
    "    input_size=H.shape[2],\n",
    "    output_size=1,\n",
    "    hidden_dim=hidden_dim,\n",
    "    n_layers=n_layers,\n",
    ")\n",
    "Discriminator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 105,
   "id": "09a42942",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([128, 24, 1])"
      ]
     },
     "execution_count": 105,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Discriminator output for H_hat; restore (batch, seq_len, 1)\n",
    "# from the input shape instead of hard-coding 128 / 24 / 1.\n",
    "Y_fake, _ = Discriminator(H_hat)\n",
    "\n",
    "Y_fake = Y_fake.reshape(H_hat.shape[0], H_hat.shape[1], -1)\n",
    "\n",
    "Y_fake.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 106,
   "id": "88b70d43",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([128, 24, 1])"
      ]
     },
     "execution_count": 106,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Discriminator output for the embedded real data H; restore (batch, seq_len, 1)\n",
    "# from the input shape instead of hard-coding 128 / 24 / 1.\n",
    "Y_real, _ = Discriminator(H)\n",
    "\n",
    "Y_real = Y_real.reshape(H.shape[0], H.shape[1], -1)\n",
    "\n",
    "Y_real.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 107,
   "id": "131f4142",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([128, 24, 1])"
      ]
     },
     "execution_count": 107,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Discriminator output for e_hat. The input was H before, which made Y_fake_e\n",
    "# an exact duplicate of Y_real. Batch and sequence dimensions are taken from e_hat.\n",
    "Y_fake_e, _ = Discriminator(e_hat)\n",
    "\n",
    "Y_fake_e = Y_fake_e.reshape(e_hat.shape[0], e_hat.shape[1], -1)\n",
    "\n",
    "Y_fake_e.shape"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "torch",
   "language": "python",
   "name": "torch"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.11"
  },
  "varInspector": {
   "cols": {
    "lenName": 16,
    "lenType": 16,
    "lenVar": 40
   },
   "kernels_config": {
    "python": {
     "delete_cmd_postfix": "",
     "delete_cmd_prefix": "del ",
     "library": "var_list.py",
     "varRefreshCmd": "print(var_dic_list())"
    },
    "r": {
     "delete_cmd_postfix": ") ",
     "delete_cmd_prefix": "rm(",
     "library": "var_list.r",
     "varRefreshCmd": "cat(var_dic_list()) "
    }
   },
   "types_to_exclude": [
    "module",
    "function",
    "builtin_function_or_method",
    "instance",
    "_Feature"
   ],
   "window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
